# coding:utf-8
'''
author:wangyi
'''
import os
import thulac
from collections import defaultdict
import pickle
'''
Pipeline:
1. Segment text and collect the category labels
2. Build label-to-id and word-to-id mappings
'''


# Segment the corpus and attach the category label to each line.
# Example: base_dir = './THUCNews'
#
# TODO: consider stop-word filtering
def cut_words_and_label(base_dir, out_dir):
    """Segment every document under ``base_dir`` with THULAC and write
    ``<label>\\t<segmented text>`` lines into train/val/test split files.

    Each sub-folder of ``base_dir`` is one category (its name is the label).
    Per category, the first 70% of files go to ``train.txt``, the next 10%
    to ``val.txt`` and the remaining 20% to ``test.txt``.

    :param base_dir: root corpus folder, one sub-folder per category
    :param out_dir: destination folder for train.txt / val.txt / test.txt
    """
    # seg_only=True: segmentation without POS tags; filt=True: drop filler words
    thu = thulac.thulac(seg_only=True, filt=True)
    # robustness: make sure the output folder exists before opening writers
    os.makedirs(out_dir, exist_ok=True)
    folders = os.listdir(base_dir)
    # context managers guarantee the writers are closed even on an exception
    with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as train_writer, \
         open(os.path.join(out_dir, 'val.txt'), 'w', encoding='utf-8') as val_writer, \
         open(os.path.join(out_dir, 'test.txt'), 'w', encoding='utf-8') as test_writer:
        for folder in folders:
            folder_path = os.path.join(base_dir, folder)
            files = os.listdir(folder_path)
            # hoist the split thresholds out of the per-line loop
            n_train = int(len(files) * 0.7)
            n_val = int(len(files) * 0.8)
            for i, file in enumerate(files):
                # choose the destination split once per file, not once per line
                if i < n_train:
                    writer = train_writer
                elif i < n_val:
                    writer = val_writer
                else:
                    writer = test_writer
                with open(os.path.join(folder_path, file), encoding='utf-8') as f:
                    for line in f:
                        # bug fix: segment each line exactly once (the original
                        # called thu.cut() twice per line — once to test, once
                        # to write), and drop the per-write flush()
                        segmented = thu.cut(line, text=True)
                        if segmented != '':
                            writer.write(folder + '\t' + segmented + '\n')
                print(folder, '第', i + 1, '条已写入！')



# base_dir contains ['train.txt', 'val.txt', 'test.txt']
def create_word2id(base_dir, vocab_size, out_path):
    """Build the label->id and word->id mappings from the split files.

    Every file in ``base_dir`` is read line by line; each line is
    ``<label>\\t<space-separated words>``. The ``vocab_size - 1`` most
    frequent words get ids 1..vocab_size-1 and id 0 is reserved for the
    ``<PAD>`` token. The result ``[label2id, word2id]`` is pickled.

    :param base_dir: folder containing train.txt / val.txt / test.txt
    :param vocab_size: total vocabulary size, including the <PAD> token
    :param out_path: path of the output pickle file
    :return: None
    """
    labels = set()
    word2count = defaultdict(int)
    for file in os.listdir(base_dir):
        # bug fix: read with explicit utf-8 — the split files were written
        # with encoding='utf-8', so relying on the platform default breaks
        # on e.g. Windows
        with open(os.path.join(base_dir, file), encoding='utf-8') as f:
            for line in f:
                parts = line.split('\t')
                labels.add(parts[0])
                for word in parts[1].strip().split(' '):
                    word2count[word] += 1  # word frequency
    label2id = {label: i for i, label in enumerate(labels)}
    # keep the vocab_size-1 most frequent words; slicing (instead of indexing
    # list(keys())[i] per word, which was O(n^2)) also avoids an IndexError
    # when the corpus has fewer distinct words than vocab_size-1
    top_words = sorted(word2count, key=word2count.get, reverse=True)[:vocab_size - 1]
    # ids start at 1; 0 is reserved for padding
    word2id = {word: i + 1 for i, word in enumerate(top_words)}
    word2id['<PAD>'] = 0
    with open(out_path, 'wb') as out_writer:
        pickle.dump([label2id, word2id], out_writer)

if __name__ == '__main__':

    corpus_dir = './THUCNews'
    dataset_dir = './datasets'
    # Step 1: segment the raw corpus into the train/val/test split files.
    cut_words_and_label(corpus_dir, dataset_dir)
    # Step 2: build the id mappings — enable once step 1 has produced the splits.
    #create_word2id('./datasets',2000,'./datasets/data2id.pkl')





